{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Viewing the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /home/lilin/model/model_mnist/train-images-idx3-ubyte.gz\n",
      "Extracting /home/lilin/model/model_mnist/train-labels-idx1-ubyte.gz\n",
      "Extracting /home/lilin/model/model_mnist/t10k-images-idx3-ubyte.gz\n",
      "Extracting /home/lilin/model/model_mnist/t10k-labels-idx1-ubyte.gz\n",
      "('training data size: ', 55000)\n",
      "('Validating data size: ', 5000)\n",
      "('Test data size: ', 10000)\n",
      "('example training data:', array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.38039219,  0.37647063,  0.3019608 ,\n",
      "        0.46274513,  0.2392157 ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.35294119,  0.5411765 ,  0.92156869,\n",
      "        0.92156869,  0.92156869,  0.92156869,  0.92156869,  0.92156869,\n",
      "        0.98431379,  0.98431379,  0.97254908,  0.99607849,  0.96078438,\n",
      "        0.92156869,  0.74509805,  0.08235294,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.54901963,\n",
      "        0.98431379,  0.99607849,  0.99607849,  0.99607849,  0.99607849,\n",
      "        0.99607849,  0.99607849,  0.99607849,  0.99607849,  0.99607849,\n",
      "        0.99607849,  0.99607849,  0.99607849,  0.99607849,  0.99607849,\n",
      "        0.74117649,  0.09019608,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.88627458,  0.99607849,  0.81568635,\n",
      "        0.78039223,  0.78039223,  0.78039223,  0.78039223,  0.54509807,\n",
      "        0.2392157 ,  0.2392157 ,  0.2392157 ,  0.2392157 ,  0.2392157 ,\n",
      "        0.50196081,  0.8705883 ,  0.99607849,  0.99607849,  0.74117649,\n",
      "        0.08235294,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.14901961,  0.32156864,  0.0509804 ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.13333334,\n",
      "        0.83529419,  0.99607849,  0.99607849,  0.45098042,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.32941177,  0.99607849,\n",
      "        0.99607849,  0.91764712,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.32941177,  0.99607849,  0.99607849,  0.91764712,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.41568631,  0.6156863 ,\n",
      "        0.99607849,  0.99607849,  0.95294124,  0.20000002,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.09803922,  0.45882356,  0.89411771,  0.89411771,\n",
      "        0.89411771,  0.99215692,  0.99607849,  0.99607849,  0.99607849,\n",
      "        0.99607849,  0.94117653,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.26666668,  0.4666667 ,  0.86274517,\n",
      "        0.99607849,  0.99607849,  0.99607849,  0.99607849,  0.99607849,\n",
      "        0.99607849,  0.99607849,  0.99607849,  0.99607849,  0.55686277,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.14509805,  0.73333335,\n",
      "        0.99215692,  0.99607849,  0.99607849,  0.99607849,  0.87450987,\n",
      "        0.80784321,  0.80784321,  0.29411766,  0.26666668,  0.84313732,\n",
      "        0.99607849,  0.99607849,  0.45882356,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.44313729,  0.8588236 ,  0.99607849,  0.94901967,  0.89019614,\n",
      "        0.45098042,  0.34901962,  0.12156864,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.7843138 ,  0.99607849,  0.9450981 ,\n",
      "        0.16078432,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.66274512,  0.99607849,\n",
      "        0.6901961 ,  0.24313727,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.18823531,\n",
      "        0.90588242,  0.99607849,  0.91764712,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.07058824,  0.48627454,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.32941177,  0.99607849,  0.99607849,\n",
      "        0.65098041,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.54509807,  0.99607849,  0.9333334 ,  0.22352943,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.82352948,  0.98039222,  0.99607849,\n",
      "        0.65882355,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.94901967,  0.99607849,  0.93725497,  0.22352943,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.34901962,  0.98431379,  0.9450981 ,\n",
      "        0.33725491,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.01960784,\n",
      "        0.80784321,  0.96470594,  0.6156863 ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.01568628,  0.45882356,  0.27058825,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,\n",
      "        0.        ,  0.        ,  0.        ,  0.        ], dtype=float32))\n",
      "('example training data: ', array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.]))\n"
     ]
    }
   ],
   "source": [
    "# Inspect the MNIST dataset: split sizes plus one example image and its label.\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "# Directory the MNIST archives are read from.\n",
    "# NOTE: machine-specific absolute path; adjust when running elsewhere.\n",
    "MNIST_DIR = \"/home/lilin/model/model_mnist/\"\n",
    "\n",
    "mnist = input_data.read_data_sets(MNIST_DIR, one_hot=True)\n",
    "\n",
    "# Each message is built with str.format as a single string: on the Python 2\n",
    "# kernel this notebook uses, print(\"label\", value) would print a tuple instead.\n",
    "\n",
    "# size of the training split\n",
    "print(\"training data size: {}\".format(mnist.train.num_examples))\n",
    "\n",
    "# size of the validation split\n",
    "print(\"Validating data size: {}\".format(mnist.validation.num_examples))\n",
    "\n",
    "# size of the test split\n",
    "print(\"Test data size: {}\".format(mnist.test.num_examples))\n",
    "\n",
    "# first training image\n",
    "print(\"example training data: {}\".format(mnist.train.images[0]))\n",
    "\n",
    "# label of that same image (one-hot encoded because of one_hot=True)\n",
    "print(\"example training data label: {}\".format(mnist.train.labels[0]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "0.9244\n"
     ]
    }
   ],
   "source": [
    "# Softmax regression on MNIST with no hidden layer (model_02).\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import tensorflow as tf\n",
    "\n",
    "# TensorBoard log directory.\n",
    "# NOTE: machine-specific absolute path; adjust when running elsewhere.\n",
    "LOG_DIR = \"/home/lilin/model/model_02\"\n",
    "\n",
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)\n",
    "\n",
    "# A dedicated graph keeps this model's ops separate from those built by other\n",
    "# cells, and the with-block closes the session when the cell finishes.\n",
    "with tf.Graph().as_default() as graph, tf.Session() as sess:\n",
    "    # input images (784 values each) and one-hot target labels\n",
    "    x = tf.placeholder(tf.float32, [None, 784])\n",
    "    y_ = tf.placeholder(tf.float32, [None, 10])\n",
    "\n",
    "    # model parameters, initialised to zero\n",
    "    w = tf.Variable(tf.zeros([784, 10]))\n",
    "    b = tf.Variable(tf.zeros([10]))\n",
    "\n",
    "    # raw class scores, and their softmax used for predictions\n",
    "    logits = tf.matmul(x, w) + b\n",
    "    y = tf.nn.softmax(logits)\n",
    "\n",
    "    # The loss is computed from the raw logits: -sum(y_ * log(softmax(logits)))\n",
    "    # turns into NaN as soon as one softmax output underflows to 0.\n",
    "    cross_entropy = tf.reduce_mean(\n",
    "        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))\n",
    "\n",
    "    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n",
    "\n",
    "    # fraction of examples whose arg-max prediction matches the label\n",
    "    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    for i in range(10000):\n",
    "        batch_xs, batch_ys = mnist.train.next_batch(100)\n",
    "        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})\n",
    "\n",
    "    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))\n",
    "\n",
    "    # write the computation graph to the log directory\n",
    "    writer = tf.summary.FileWriter(LOG_DIR, graph)\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "0.9757\n"
     ]
    }
   ],
   "source": [
    "# TensorFlow network with one hidden layer (multilayer perceptron) on MNIST.\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import tensorflow as tf\n",
    "\n",
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)\n",
    "\n",
    "in_units = 784\n",
    "h1_units = 300\n",
    "\n",
    "# A dedicated graph keeps this model's ops separate from those built by other\n",
    "# cells, and the with-block closes the session when the cell finishes.\n",
    "with tf.Graph().as_default(), tf.Session() as sess:\n",
    "    # hidden layer gets small random weights; the output layer starts at zero\n",
    "    w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))\n",
    "    b1 = tf.Variable(tf.zeros([h1_units]))\n",
    "    w2 = tf.Variable(tf.zeros([h1_units, 10]))\n",
    "    b2 = tf.Variable(tf.zeros([10]))\n",
    "\n",
    "    # inputs: images, one-hot target labels, and the dropout keep probability\n",
    "    x = tf.placeholder(tf.float32, [None, in_units])\n",
    "    y_ = tf.placeholder(tf.float32, [None, 10])\n",
    "    keep_prob = tf.placeholder(tf.float32)\n",
    "\n",
    "    # hidden layer: ReLU followed by dropout\n",
    "    hidden1 = tf.nn.relu(tf.matmul(x, w1) + b1)\n",
    "    hidden1_drop = tf.nn.dropout(hidden1, keep_prob)\n",
    "\n",
    "    # raw class scores, and their softmax used for predictions\n",
    "    logits = tf.matmul(hidden1_drop, w2) + b2\n",
    "    y = tf.nn.softmax(logits)\n",
    "\n",
    "    # The loss is computed from the raw logits: -sum(y_ * log(softmax(logits)))\n",
    "    # turns into NaN as soon as one softmax output underflows to 0.\n",
    "    cross_entropy = tf.reduce_mean(\n",
    "        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))\n",
    "    train_step = tf.train.AdagradOptimizer(0.3).minimize(cross_entropy)\n",
    "\n",
    "    # fraction of examples whose arg-max prediction matches the label\n",
    "    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    for i in range(3000):\n",
    "        batch_xs, batch_ys = mnist.train.next_batch(100)\n",
    "        # train with half of the hidden activations dropped\n",
    "        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})\n",
    "\n",
    "    # evaluate with dropout disabled (keep_prob = 1.0)\n",
    "    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "step 0, training accuracy 0.04\n",
      "step 1000, training accuracy 0.98\n",
      "step 2000, training accuracy 0.98\n",
      "step 3000, training accuracy 0.98\n",
      "step 4000, training accuracy 1\n",
      "step 5000, training accuracy 1\n",
      "step 6000, training accuracy 0.98\n",
      "step 7000, training accuracy 1\n",
      "step 8000, training accuracy 0.98\n",
      "step 9000, training accuracy 1\n",
      "step 10000, training accuracy 1\n",
      "step 11000, training accuracy 1\n",
      "step 12000, training accuracy 0.98\n",
      "step 13000, training accuracy 1\n",
      "step 14000, training accuracy 1\n",
      "step 15000, training accuracy 1\n",
      "step 16000, training accuracy 1\n",
      "step 17000, training accuracy 1\n",
      "step 18000, training accuracy 1\n",
      "step 19000, training accuracy 1\n",
      "test accuracy 0.9926\n"
     ]
    }
   ],
   "source": [
    "# LeNet-style convolutional network on MNIST (model_03).\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import tensorflow as tf\n",
    "\n",
    "# TensorBoard log directory.\n",
    "# NOTE: machine-specific absolute path; adjust when running elsewhere.\n",
    "LOG_DIR = \"/home/lilin/model/model_03\"\n",
    "\n",
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)\n",
    "\n",
    "\n",
    "def weight_variable(shape):\n",
    "    \"\"\"Return a tf.Variable of `shape` initialised from a truncated normal (stddev 0.1).\"\"\"\n",
    "    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))\n",
    "\n",
    "\n",
    "def bias_variable(shape):\n",
    "    \"\"\"Return a tf.Variable of `shape` with every element initialised to 0.1.\"\"\"\n",
    "    return tf.Variable(tf.constant(0.1, shape=shape))\n",
    "\n",
    "\n",
    "def conv2d(x, w):\n",
    "    \"\"\"Convolve `x` with the filter `w` using stride 1 and 'SAME' padding.\"\"\"\n",
    "    return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')\n",
    "\n",
    "\n",
    "def max_pool_2x2(x):\n",
    "    \"\"\"Max-pool `x` with a 2x2 window, stride 2 and 'SAME' padding.\"\"\"\n",
    "    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n",
    "\n",
    "\n",
    "# A dedicated graph ensures only this model is written to LOG_DIR, and the\n",
    "# with-block closes the session when the cell finishes.\n",
    "with tf.Graph().as_default() as graph, tf.Session() as sess:\n",
    "    # input images, one-hot target labels, and the images reshaped to 28x28x1\n",
    "    x = tf.placeholder(tf.float32, [None, 784])\n",
    "    y_ = tf.placeholder(tf.float32, [None, 10])\n",
    "    x_image = tf.reshape(x, [-1, 28, 28, 1])\n",
    "\n",
    "    # conv1: 5x5 filters, 1 -> 32 channels, then 2x2 max-pooling\n",
    "    w_conv1 = weight_variable([5, 5, 1, 32])\n",
    "    b_conv1 = bias_variable([32])\n",
    "    h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)\n",
    "    h_pool1 = max_pool_2x2(h_conv1)\n",
    "\n",
    "    # conv2: 5x5 filters, 32 -> 64 channels, then 2x2 max-pooling\n",
    "    w_conv2 = weight_variable([5, 5, 32, 64])\n",
    "    b_conv2 = bias_variable([64])\n",
    "    h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)\n",
    "    h_pool2 = max_pool_2x2(h_conv2)\n",
    "\n",
    "    # fc1: the pooled 7x7x64 features are flattened and mapped to 1024 units\n",
    "    w_fc1 = weight_variable([7 * 7 * 64, 1024])\n",
    "    b_fc1 = bias_variable([1024])\n",
    "    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])\n",
    "    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)\n",
    "\n",
    "    # dropout on fc1; keep_prob is fed at run time\n",
    "    keep_prob = tf.placeholder(tf.float32)\n",
    "    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)\n",
    "\n",
    "    # output layer: 10 raw class scores, and their softmax used for predictions\n",
    "    w_fc2 = weight_variable([1024, 10])\n",
    "    b_fc2 = bias_variable([10])\n",
    "    logits = tf.matmul(h_fc1_drop, w_fc2) + b_fc2\n",
    "    y_conv = tf.nn.softmax(logits)\n",
    "\n",
    "    # The loss is computed from the raw logits: -sum(y_ * log(softmax(logits)))\n",
    "    # turns into NaN as soon as one softmax output underflows to 0.\n",
    "    cross_entropy = tf.reduce_mean(\n",
    "        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))\n",
    "    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)\n",
    "\n",
    "    # fraction of examples whose arg-max prediction matches the label\n",
    "    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "    # train the network\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    for i in range(20000):\n",
    "        batch = mnist.train.next_batch(50)\n",
    "        if i % 1000 == 0:\n",
    "            # accuracy on the current batch, with dropout disabled\n",
    "            train_accuracy = sess.run(\n",
    "                accuracy, feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})\n",
    "            print(\"step %d, training accuracy %g\" % (i, train_accuracy))\n",
    "        sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})\n",
    "\n",
    "    # evaluate on the test set, with dropout disabled\n",
    "    test_accuracy = sess.run(\n",
    "        accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})\n",
    "    print(\"test accuracy %g\" % test_accuracy)\n",
    "\n",
    "    # write the graph to file\n",
    "    writer = tf.summary.FileWriter(LOG_DIR, graph)\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
